#include "Model.h"
#include <iostream>
#include <assimp/Importer.hpp>
#include <assimp/postprocess.h>
#include <util.h>
#include <glm/glm.hpp>
#include <glm/gtx/euler_angles.hpp>
#include <glm/gtc/type_ptr.hpp>

#include <Singleton.h>

using namespace std;
using namespace Assimp;

static inline glm::vec3 vec3_cast(const aiVector3D& v)
{
    // Component-wise copy from Assimp's 3D vector into GLM's.
    return glm::vec3(v.x, v.y, v.z);
}

static inline glm::vec2 vec2_cast(const aiVector2D& v)
{
    // Component-wise copy from Assimp's 2D vector into GLM's.
    return glm::vec2(v.x, v.y);
}

static inline glm::quat quat_cast(const aiQuaternion& q)
{
    // glm::quat's component constructor takes the scalar part (w) first.
    return glm::quat(q.w, q.x, q.y, q.z);
}

static inline glm::mat4 mat4_cast(const aiMatrix4x4& mat)
{
    // Assimp stores matrices row-major (a1..a4 is the first row) while GLM
    // loads raw arrays column-major, so the loaded matrix is transposed.
    const glm::mat4 loaded = glm::make_mat4(&mat.a1);
    return glm::transpose(loaded);
}

static inline glm::mat3 mat3_cast(const aiMatrix3x3& mat)
{
    // Same row-major -> column-major conversion as mat4_cast, for 3x3 matrices.
    const glm::mat3 loaded = glm::make_mat3(&mat.a1);
    return glm::transpose(loaded);
}

Model::Model(const string &path)
{
    // All of the work (meshes, textures, baked animations) happens in LoadModel.
    this->LoadModel(path);
}

void Model::Draw(const gl::IShader &shader) const
{
    // Render each sub-mesh, in load order, with the same shader.
    for (auto it = meshes.begin(); it != meshes.end(); ++it)
        it->Draw(shader);
}

void Model::Play(const string &animName, uint32_t frame)
{
    auto anim = cachedAnim.GetOrNull(animName);
    if (anim.has_value()) {
        auto frameId = frame % anim.value()->GetFrameSize();
        for (auto &i: meshes)
            i.SetAnimVec(anim.value()->GetFrameAnim(frameId));
        return;
    }
    cerr << "Can not find anim: " << animName << endl;
}

void Model::SetModelMatrix(const glm::mat4& mMat)
{
    // Every sub-mesh shares the model's world transform.
    for (auto it = meshes.begin(); it != meshes.end(); ++it)
        it->SetModelMatrix(mMat);
}

void Model::LoadModel(const string &path)
{
    filePath = path;
    Importer importer;
    const aiScene *scene = importer.ReadFile(path, aiProcess_Triangulate | aiProcess_FlipUVs);

    if(!scene || scene->mFlags & AI_SCENE_FLAGS_INCOMPLETE || !scene->mRootNode)
    {
        cerr << "ERROR::ASSIMP::" << importer.GetErrorString() << endl;
        return;
    }
    directory = path.substr(0, path.find_last_of('/'));

    LoadMeshes(scene);
    LoadAnimations(scene);
}

// Returns the first texture of `type` on `material` that can be loaded.
// Image paths are resolved as `directory + "/" + <material path>`.
// Already-loaded images come from the Singleton texture cache; newly loaded
// ones are added to it. If no texture can be resolved, returns {"", 0, 0}.
// `scene` and `typeName` are not used by the current implementation.
Texture Model::LoadMaterialTextures(const aiScene* scene, aiMaterial *material, aiTextureType type, const string &typeName)
{
    (void)scene;
    (void)typeName;

    const unsigned int textureCount = material->GetTextureCount(type);
    for (unsigned int i = 0; i < textureCount; ++i)
    {
        aiString path;
        // Skip slots whose path cannot be read instead of using an empty path.
        if (material->GetTexture(type, i, &path) != aiReturn_SUCCESS)
            continue;

        string imagePath = directory + "/" + path.C_Str();
        auto cached = Singleton<Cache<string, Texture>>::Instance().GetOrNull(imagePath);
        if (cached.has_value())
            return cached.value();

        uint32_t texId = util::TextureFromFile(imagePath);
        if (texId == 0)
            continue;  // load failed; try the next texture slot

        Texture texture{};
        texture.Index = i;
        texture.TexId = texId;
        texture.Uid = imagePath;
        Singleton<Cache<string, Texture>>::Instance().AddOrUpdate(imagePath, texture);
        return texture;
    }
    return {"", 0, 0};
}

// Returns the channel of `animation` whose node name equals `nodeName`, or
// nullptr if the node is not animated. This is a linear scan over the channels.
const aiNodeAnim* Model::FindNodeAnim(const aiAnimation* animation, const string& nodeName)
{
    for (uint32_t i = 0; i < animation->mNumChannels; ++i)
    {
        const aiNodeAnim* nodeAnim = animation->mChannels[i];
        // Compare directly against the C string so no temporary std::string
        // is allocated per channel on every lookup.
        if (nodeName == nodeAnim->mNodeName.C_Str())
            return nodeAnim;
    }
    return nullptr;
}

// Normalized position of `i` relative to [start, end]: 0 at start, 1 at end
// (values outside the interval extrapolate linearly). A degenerate interval
// (start == end) yields 0 instead of dividing by zero.
static double CalcP(double i, double start, double end)
{
    const double span = end - start;
    if (span == 0.0)
        return 0.0;
    return (i - start) / span;
}

// Samples the rotation track of `nodeAnim` at `animationTime` (in ticks) and
// returns it as a 4x4 rotation matrix:
//  - no rotation keys               -> identity
//  - a single key / before key 0    -> first key
//  - at or after the last key       -> last key
//  - otherwise                      -> aiQuaternion::Interpolate between the
//                                      two keys surrounding the time
// Times are compared in double precision against mTime, so the bracketing
// keys i and i + 1 are always valid indices.
static glm::mat4 CalcInterpolatedRotation(float animationTime, const aiNodeAnim* nodeAnim)
{
    const uint32_t keyCount = nodeAnim->mNumRotationKeys;
    const auto* keys = nodeAnim->mRotationKeys;
    if (keyCount == 0 || keys == nullptr)
        return glm::mat4(1.0f);

    const double t = animationTime;
    if (keyCount == 1 || t < keys[0].mTime)
        return glm::mat4(mat3_cast(keys[0].mValue.GetMatrix()));

    for (uint32_t i = 0; i + 1 < keyCount; ++i)
    {
        if (t < keys[i + 1].mTime)
        {
            // keys[i].mTime <= t < keys[i + 1].mTime, so the factor is in [0, 1).
            const auto p = static_cast<float>(CalcP(t, keys[i].mTime, keys[i + 1].mTime));
            aiQuaternion out;
            aiQuaternion::Interpolate(out, keys[i].mValue, keys[i + 1].mValue, p);
            return glm::mat4(mat3_cast(out.Normalize().GetMatrix()));
        }
    }

    // t is at or past the last key: hold the final pose.
    return glm::mat4(mat3_cast(keys[keyCount - 1].mValue.GetMatrix()));
}

// Samples the position track of `nodeAnim` at `animationTime` (in ticks) and
// returns it as a translation matrix:
//  - no position keys               -> identity
//  - a single key / before key 0    -> first key
//  - at or after the last key       -> last key
//  - otherwise                      -> linear interpolation between the two
//                                      keys surrounding the time
// Times are compared in double precision against mTime, so the bracketing
// keys i and i + 1 are always valid indices.
static glm::mat4 CalcInterpolatedPosition(float animationTime, const aiNodeAnim* nodeAnim)
{
    const uint32_t keyCount = nodeAnim->mNumPositionKeys;
    const auto* keys = nodeAnim->mPositionKeys;
    if (keyCount == 0 || keys == nullptr)
        return glm::mat4(1.0f);

    const double t = animationTime;
    if (keyCount == 1 || t < keys[0].mTime)
        return glm::translate(glm::mat4(1.0f), vec3_cast(keys[0].mValue));

    for (uint32_t i = 0; i + 1 < keyCount; ++i)
    {
        if (t < keys[i + 1].mTime)
        {
            // keys[i].mTime <= t < keys[i + 1].mTime, so the factor is in [0, 1).
            const auto p = static_cast<float>(CalcP(t, keys[i].mTime, keys[i + 1].mTime));
            const auto& start = keys[i].mValue;
            const auto& end = keys[i + 1].mValue;
            return glm::translate(glm::mat4(1.0f), vec3_cast(start + p * (end - start)));
        }
    }

    // t is at or past the last key: hold the final position.
    return glm::translate(glm::mat4(1.0f), vec3_cast(keys[keyCount - 1].mValue));
}


// Samples the scaling track of `nodeAnim` at `animationTime` (in ticks) and
// returns it as a scale matrix:
//  - no scaling keys                -> identity
//  - a single key / before key 0    -> first key
//  - at or after the last key       -> last key
//  - otherwise                      -> linear interpolation between the two
//                                      keys surrounding the time
// Times are compared in double precision against mTime, so the bracketing
// keys i and i + 1 are always valid indices.
static glm::mat4 CalcInterpolatedScaling(float animationTime, const aiNodeAnim* nodeAnim)
{
    const uint32_t keyCount = nodeAnim->mNumScalingKeys;
    const auto* keys = nodeAnim->mScalingKeys;
    if (keyCount == 0 || keys == nullptr)
        return glm::mat4(1.0f);

    const double t = animationTime;
    if (keyCount == 1 || t < keys[0].mTime)
        return glm::scale(glm::mat4(1.0f), vec3_cast(keys[0].mValue));

    for (uint32_t i = 0; i + 1 < keyCount; ++i)
    {
        if (t < keys[i + 1].mTime)
        {
            // keys[i].mTime <= t < keys[i + 1].mTime, so the factor is in [0, 1).
            const auto p = static_cast<float>(CalcP(t, keys[i].mTime, keys[i + 1].mTime));
            const auto& start = keys[i].mValue;
            const auto& end = keys[i + 1].mValue;
            return glm::scale(glm::mat4(1.0f), vec3_cast(start + p * (end - start)));
        }
    }

    // t is at or past the last key: hold the final scale.
    return glm::scale(glm::mat4(1.0f), vec3_cast(keys[keyCount - 1].mValue));
}

// Recursively walks the node hierarchy below `node`, accumulating global
// transforms at `animationTime` (in ticks). For each node that is a bone,
// writes the final skinning matrix
//   animGlobalInverseT * global * boneOffset
// into transforms[boneIndex]. `transforms` must already be sized to hold
// every bone index (BoneTransform does this).
void Model::NodeAnimInner(aiAnimation* anim, float animationTime, const aiNode* node, const glm::mat4& parentTransform, std::vector<glm::mat4>& transforms)
{
    const string nodeName(node->mName.C_Str());
    // Static local transform, replaced by the sampled key frames if the node is animated.
    glm::mat4 nodeTransformation = mat4_cast(node->mTransformation);

    if (const aiNodeAnim* nodeAnim = FindNodeAnim(anim, nodeName))
    {
        const auto sT = CalcInterpolatedScaling(animationTime, nodeAnim);
        const auto rT = CalcInterpolatedRotation(animationTime, nodeAnim);
        const auto pT = CalcInterpolatedPosition(animationTime, nodeAnim);
        nodeTransformation = pT * rT * sT;  // scale, then rotate, then translate
    }

    const glm::mat4 globalTransformation = parentTransform * nodeTransformation;

    // Single lookup rather than find() followed by operator[].
    const auto boneIt = boneNameIndex.find(nodeName);
    if (boneIt != boneNameIndex.end())
    {
        const uint32_t boneIndex = boneIt->second;
        transforms[boneIndex] = animGlobalInverseT * globalTransformation * boneOffsetCache[boneIndex];
    }

    for (uint32_t i = 0; i < node->mNumChildren; ++i)
        NodeAnimInner(anim, animationTime, node->mChildren[i], globalTransformation, transforms);
}

// Computes the skinning matrix of every bone for `anim` at `timeInMs`
// milliseconds into the animation (wrapped to the animation length) and
// stores them in `transforms`, indexed by bone id.
void Model::BoneTransform(const aiNode* root, aiAnimation* anim, float timeInMs, vector<glm::mat4>& transforms)
{
    // New entries start as identity, so bones not reached by the walk stay identity.
    transforms.resize(boneNameIndex.size(), glm::mat4(1.0f));

    // Assimp reports 0 ticks per second when the file does not specify it; assume 25.
    const auto ticksPerSecond = (float)(anim->mTicksPerSecond != 0 ? anim->mTicksPerSecond : 25.0f);
    const auto timeInTicks = timeInMs * ticksPerSecond / 1000;

    // Wrap into [0, duration). fmod by a zero duration returns NaN, so a
    // zero-length animation is sampled at its start instead.
    const auto duration = (float)anim->mDuration;
    const auto animationTime = duration > 0.0f ? fmod(timeInTicks, duration) : 0.0f;

    NodeAnimInner(anim, animationTime, root, glm::mat4(1.0f), transforms);
}

void Model::LoadAnimations(const aiScene *scene)
{
    auto keyDuration = 33.0f;
    auto boneCounter = boneNameIndex.size();
    for (int i = 0; i < scene->mNumAnimations; ++i)
    {
        auto& anim = scene->mAnimations[i];
        auto animName = string(anim->mName.C_Str());
        auto ticksPerSecond = (float)(anim->mTicksPerSecond != 0 ? anim->mTicksPerSecond : 25.0f);
        auto duration = anim->mDuration * 1000 / ticksPerSecond;
        auto animFrameCount = int( duration / keyDuration) + 1;

        vector<shared_ptr<vector<glm::mat4>>> frames;
        frames.reserve(animFrameCount + 1);

        {
            animGlobalInverseT = mat4_cast(scene->mRootNode->mTransformation);
            animGlobalInverseT = glm::inverse(animGlobalInverseT);
        }

        cout << animName << ": " << animFrameCount << ", " << duration << ", " << keyDuration << endl;

        for (float j = 0; j < duration;)
        {
            auto ts = make_shared<vector<glm::mat4>>();
            BoneTransform(scene->mRootNode, anim, j, *ts);
            frames.emplace_back(ts);
            j += keyDuration;
        }
        cachedAnim.AddOrUpdate(animName, make_shared<ModelAnimation>(frames));
    }
}

void Model::LoadMeshes(const aiScene *scene)
{
    for (int i = 0; i < scene->mNumMeshes; ++i)
    {
        auto& mesh = scene->mMeshes[i];
        vector<Vertex> vertices;
        vertices.resize(mesh->mNumVertices);
        vector<uint32_t> indices;
        indices.reserve(mesh->mNumFaces * 3);
        Texture texture;

        for (int j = 0; j < mesh->mNumVertices; ++j)
        {
            auto& v = vertices[j];
            v.Position.x = mesh->mVertices[j].x;
            v.Position.y = mesh->mVertices[j].y;
            v.Position.z = mesh->mVertices[j].z;

            if (mesh->mNormals != nullptr)
            {
                v.Normal.x = mesh->mNormals[j].x;
                v.Normal.y = mesh->mNormals[j].y;
                v.Normal.z = mesh->mNormals[j].z;
            }

            if (mesh->mTextureCoords[0])
            {
                v.TexCoords.x = mesh->mTextureCoords[0][j].x;
                v.TexCoords.y = mesh->mTextureCoords[0][j].y;
            }
            else
            {
                v.TexCoords.x = 0;
                v.TexCoords.y = 0;
            }
        }

        for (int j = 0; j < mesh->mNumFaces; ++j)
        {
            const aiFace& face = mesh->mFaces[j];
            for (int k = 0; k < face.mNumIndices; ++k)
                indices.emplace_back(face.mIndices[k]);
        }

        if (mesh->mMaterialIndex >= 0)
        {
            aiMaterial* material = scene->mMaterials[mesh->mMaterialIndex];
            texture = LoadMaterialTextures(scene, material, aiTextureType_DIFFUSE, "base");
        }

        for (int j = 0; j < mesh->mNumBones; ++j)
        {
            auto& bone = mesh->mBones[j];
            auto boneName = string(bone->mName.C_Str());
            uint32_t boneId = 0;
            {
                auto iter = boneNameIndex.find(boneName);
                if (iter == boneNameIndex.end())
                {
                    boneId = boneNameIndex.size();
                    boneNameIndex.insert({boneName, boneId});
                }
                else
                {
                    boneId = iter->second;
                }
            }
            auto m = mat4_cast(bone->mOffsetMatrix);
            boneOffsetCache.insert({boneId, m});
            for (int k = 0; k < bone->mNumWeights; ++k)
            {
                auto& widget = bone->mWeights[k];
                auto& v = vertices[widget.mVertexId];
                for (int l = 0; l < MAX_BONE_LENGTH; ++l)
                {
                    if(v.BoneWidget[l] == 0.0f)
                    {
                        v.BoneId[l] = boneId;
                        v.BoneWidget[l] = widget.mWeight;
                        break;
                    }
                    if (l == MAX_BONE_LENGTH - 1)
                        cerr << "more bone sort needed!" << endl;
                }
            }
        }

        meshes.emplace_back(Mesh(vertices, indices, texture));
    }
}



